{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Tiled matmul\n",
    "\n",
    "In this simple tutorial, we will see how to perform a `matmul` with tiling.\n",
    "Tiling is a technique based on matrix partitioning: each matrix is split into blocks, and each block is called a tile.\n",
    "\n",
    "With tiling:\n",
    "* the `matmul` computation can be performed in parallel, a domain where GPUs excel;\n",
    "* global memory (GM) accesses are reduced, GM access being the GPU bottleneck (compared to computation).\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "M, N, K = 15, 9, 12\n",
    "\n",
    "A = torch.rand((M, K))\n",
    "B = torch.rand((K, N))"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Simple matmul with tiling\n",
    "\n",
    "Simple example showing how we can perform a `matmul` through tiling.\n",
    "\n",
    "A basic introduction to the subject can be found here:\n",
    "\n",
    "* https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html\n",
    "* https://penny-xu.github.io/blog/tiled-matrix-multiplication\n",
    "\n",
    "Parallelization can be applied at the `M` and `N` loop levels.\n",
    "However, making the best use of global memory accesses requires being a bit smarter.\n",
    "Check our dedicated explanation in tutorials.\n",
    "\n",
    "The values used below are arbitrary, and small enough to be printable if needed.\n",
    "Rules of thumb for choosing a tile shape:\n",
    "* a large tile size increases data reuse, but decreases thread-level parallelism;\n",
    "* a small tile size increases thread-level parallelism, but reduces data reuse."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total load from GM: 864\n",
      "total write to GM: 135\n"
     ]
    }
   ],
   "source": [
    "# for simplicity, every matrix dimension is a multiple of the matching tile dimension\n",
    "# otherwise we would need to check matrix bounds and fill out-of-bounds tile values with 0s\n",
    "block_M, block_N, block_K = M // 3, N // 3, K // 2\n",
    "\n",
    "output = torch.zeros((M, N))\n",
    "total_load = 0\n",
    "total_write = 0\n",
    "\n",
    "for index_M in range(0, M, block_M):\n",
    "    start_M = index_M\n",
    "    end_M = index_M + block_M\n",
    "\n",
    "    for index_N in range(0, N, block_N):\n",
    "        start_N = index_N\n",
    "        end_N = index_N + block_N\n",
    "        accumulator = torch.zeros((block_M, block_N))\n",
    "        for index_K in range(0, K, block_K):\n",
    "            start_K = index_K\n",
    "            end_K = index_K + block_K\n",
    "\n",
    "            tile_A = A[start_M:end_M, start_K:end_K]\n",
    "            total_load += tile_A.numel()\n",
    "            tile_B = B[start_K:end_K, start_N:end_N]\n",
    "            total_load += tile_B.numel()\n",
    "            # @ means matmul in numpy and pytorch\n",
    "            accumulator += tile_A @ tile_B\n",
    "        output[start_M:end_M, start_N:end_N] = accumulator\n",
    "        total_write += accumulator.numel()\n",
    "\n",
    "assert torch.allclose(output, A @ B)\n",
    "print(\"total load from GM:\", total_load)\n",
    "print(\"total write to GM:\", total_write)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "In the code above, we tracked the number of elements loaded from and written to global memory (GPU DRAM).\n",
    "The number of elements written is easy to guess: it is the number of elements in the output matrix,\n",
    "i.e. `M * N`."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "135"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "M * N"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "For loads, the count is (tile A size + tile B size) multiplied by the number of tiles along each of the `M`, `N` and `K` axes."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "data": {
      "text/plain": "864.0"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "((block_M * block_K) + (block_K * block_N)) * (K / block_K) * (N / block_N) * (M / block_M)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
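  {
   "cell_type": "markdown",
   "source": [
    "A minimal sketch to explore how the tile shape changes the number of loads.\n",
    "`count_gm_loads` is a hypothetical helper (not part of any library): it wraps the expression above using integer arithmetic,\n",
    "and assumes that every block size divides its matrix dimension."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def count_gm_loads(M, N, K, block_M, block_N, block_K):\n",
    "    # hypothetical helper: same count as the expression above, in integer arithmetic\n",
    "    # assumption: every block size divides its matrix dimension\n",
    "    assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0\n",
    "    n_steps = (M // block_M) * (N // block_N) * (K // block_K)\n",
    "    # each step loads one tile of A and one tile of B\n",
    "    return (block_M * block_K + block_K * block_N) * n_steps\n",
    "\n",
    "\n",
    "# same matrix shapes as above (M=15, N=9, K=12), block_K fixed to 6\n",
    "for bm, bn in [(15, 9), (5, 3), (3, 3), (1, 1)]:\n",
    "    print((bm, bn), count_gm_loads(15, 9, 12, bm, bn, 6))\n",
    "# prints:\n",
    "# (15, 9) 288\n",
    "# (5, 3) 864\n",
    "# (3, 3) 1080\n",
    "# (1, 1) 3240"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },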
  {
   "cell_type": "markdown",
   "source": [
    "Note that making `block_N` and `block_M` smaller increases the number of loads:\n",
    "the load expression simplifies to `M * N * K * (1 / block_M + 1 / block_N)`, which does not depend on `block_K`."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
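  ,
  {
   "cell_type": "markdown",
   "source": [
    "The tiled loop earlier in this notebook relies on every matrix dimension being a multiple of the tile dimension.\n",
    "Below is a minimal sketch of the alternative mentioned in its comments: tiles keep a fixed shape and out-of-bounds values are replaced by 0s.\n",
    "`load_tile` and `tiled_matmul_padded` are hypothetical helper names, and the block sizes (4, 4, 5) are arbitrary values chosen so that none of them divides its dimension."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def load_tile(matrix, row, col, n_rows, n_cols):\n",
    "    # hypothetical helper: fixed-shape tile, zero-filled where it falls outside the matrix\n",
    "    tile = torch.zeros((n_rows, n_cols))\n",
    "    chunk = matrix[row:row + n_rows, col:col + n_cols]  # slicing clips at the matrix border\n",
    "    tile[:chunk.shape[0], :chunk.shape[1]] = chunk\n",
    "    return tile\n",
    "\n",
    "\n",
    "def tiled_matmul_padded(a, b, block_m, block_n, block_k):\n",
    "    m, k = a.shape\n",
    "    _, n = b.shape\n",
    "    output = torch.zeros((m, n))\n",
    "    for start_m in range(0, m, block_m):\n",
    "        for start_n in range(0, n, block_n):\n",
    "            accumulator = torch.zeros((block_m, block_n))\n",
    "            for start_k in range(0, k, block_k):\n",
    "                tile_a = load_tile(a, start_m, start_k, block_m, block_k)\n",
    "                tile_b = load_tile(b, start_k, start_n, block_k, block_n)\n",
    "                # zeros padded along K add nothing to the dot products\n",
    "                accumulator += tile_a @ tile_b\n",
    "            # write back only the rows and columns that exist in the output\n",
    "            rows = min(block_m, m - start_m)\n",
    "            cols = min(block_n, n - start_n)\n",
    "            output[start_m:start_m + rows, start_n:start_n + cols] = accumulator[:rows, :cols]\n",
    "    return output\n",
    "\n",
    "\n",
    "# lowercase names so that A, B and the block sizes defined earlier are left untouched\n",
    "a, b = torch.rand((15, 12)), torch.rand((12, 9))\n",
    "assert torch.allclose(tiled_matmul_padded(a, b, 4, 4, 5), a @ b)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }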
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}